import torch
import os
import numpy as np
import taichi as ti
import mcubes
from scipy.spatial import KDTree
import time

# 1. densify grids
# 2. identify grids whose density is larger than some threshold
# 3. filling grids with particles
# 4. identify and fill internal grids


@ti.func
def compute_density(index, pos, opacity, cov, grid_dx):
    """Return the opacity-weighted Gaussian value averaged over one cell's 8 corners.

    The cell is identified by the integer vector ``index``. Its corners are
    ``(index + offset) * grid_dx`` for every offset in {0, 1}^3. For each corner
    ``c`` the term ``exp(-0.5 * d^T cov d)`` with ``d = pos - c`` is summed.
    The mean of the 8 terms is then scaled by ``opacity``.

    ``cov`` is used directly as the quadratic-form matrix; densify_grids passes
    the inverted covariance here.
    """
    weight_sum = 0.0
    for ox in range(2):
        for oy in range(2):
            for oz in range(2):
                corner = (index + ti.Vector([ox, oy, oz])) * grid_dx
                delta = pos - corner
                weight_sum += ti.exp(-0.5 * delta.dot(cov @ delta))

    # Average over the 8 corners, then scale by the particle's opacity.
    return opacity * weight_sum / 8.0


# Legacy version of densify_grids, disabled by wrapping it in a string literal.
# Unlike the active densify_grids defined below, it does not bounds-check the
# particle's own cell before incrementing ``grid`` and has no NaN/inf guard on
# the covariance.
'''
@ti.kernel
def densify_grids(
    init_particles: ti.template(),
    opacity: ti.template(),
    cov_upper: ti.template(),
    grid: ti.template(),
    grid_density: ti.template(),
    grid_dx: float,
):
    for pi in range(init_particles.shape[0]):
        pos = init_particles[pi]
        x = pos[0]
        y = pos[1]
        z = pos[2]
        i = ti.floor(x / grid_dx, dtype=int)
        j = ti.floor(y / grid_dx, dtype=int)
        k = ti.floor(z / grid_dx, dtype=int)
        ti.atomic_add(grid[i, j, k], 1)
        cov = ti.Matrix(
            [
                [cov_upper[pi][0], cov_upper[pi][1], cov_upper[pi][2]],
                [cov_upper[pi][1], cov_upper[pi][3], cov_upper[pi][4]],
                [cov_upper[pi][2], cov_upper[pi][4], cov_upper[pi][5]],
            ]
        )
        sig, Q = ti.sym_eig(cov)
        sig[0] = ti.max(sig[0], 1e-8)
        sig[1] = ti.max(sig[1], 1e-8)
        sig[2] = ti.max(sig[2], 1e-8)
        sig_mat = ti.Matrix(
            [[1.0 / sig[0], 0, 0], [0, 1.0 / sig[1], 0], [0, 0, 1.0 / sig[2]]]
        )
        cov = Q @ sig_mat @ Q.transpose()
        r = 0.0
        for idx in ti.static(range(3)):
            if sig[idx] < 0:
                sig[idx] = ti.sqrt(-sig[idx])
            else:
                sig[idx] = ti.sqrt(sig[idx])

            r = ti.max(r, sig[idx])

        r = ti.ceil(r / grid_dx, dtype=int)
        for dx in range(-r, r + 1):
            for dy in range(-r, r + 1):
                for dz in range(-r, r + 1):
                    if (
                        i + dx >= 0
                        and i + dx < grid_density.shape[0]
                        and j + dy >= 0
                        and j + dy < grid_density.shape[1]
                        and k + dz >= 0
                        and k + dz < grid_density.shape[2]
                    ):
                        density = compute_density(
                            ti.Vector([i + dx, j + dy, k + dz]),
                            pos,
                            opacity[pi],
                            cov,
                            grid_dx,
                        )
                        ti.atomic_add(grid_density[i + dx, j + dy, k + dz], density)


'''






                    
            
@ti.kernel
def densify_grids(
    init_particles: ti.template(),
    opacity: ti.template(),
    cov_upper: ti.template(),
    grid: ti.template(),
    grid_density: ti.template(),
    grid_dx: float,
):
    """Splat every particle onto the grid: count particles per cell and accumulate density.

    For each particle ``pi``:
      * ``grid`` at the cell containing its center is incremented, but only
        when that cell lies inside ``grid``.
      * The 6 stored components ``cov_upper[pi]`` are expanded into a
        symmetric 3x3 matrix.
      * That matrix is eigen-decomposed and inverted, with eigenvalues clamped
        to >= 1e-8.
      * ``compute_density`` (using the inverted matrix) is added to every
        in-bounds cell of ``grid_density`` inside a cube of half-width
        ``r_cells``. Here ``r_cells`` is ceil(sqrt(largest clamped
        eigenvalue) / grid_dx), capped at 6 with a warning.

    A particle whose matrix contains NaN or inf is reported and skipped. Its
    center has already been counted in ``grid`` at that point.
    """
    for pi in range(init_particles.shape[0]):
        pos = init_particles[pi]
        x, y, z = pos[0], pos[1], pos[2]
        i = ti.floor(x / grid_dx, dtype=int)
        j = ti.floor(y / grid_dx, dtype=int)
        k = ti.floor(z / grid_dx, dtype=int)

        # Count the particle in its own cell, if that cell is inside the grid.
        if 0 <= i < grid.shape[0] and 0 <= j < grid.shape[1] and 0 <= k < grid.shape[2]:
            ti.atomic_add(grid[i, j, k], 1)

        # Expand the 6 stored components into the full symmetric 3x3 matrix.
        cov = ti.Matrix([
            [cov_upper[pi][0], cov_upper[pi][1], cov_upper[pi][2]],
            [cov_upper[pi][1], cov_upper[pi][3], cov_upper[pi][4]],
            [cov_upper[pi][2], cov_upper[pi][4], cov_upper[pi][5]],
        ])

        # Skip particles whose matrix contains NaN or inf.
        valid = True
        for p in ti.static(range(3)):
            for q in ti.static(range(3)):
                if ti.math.isnan(cov[p, q]) or ti.math.isinf(cov[p, q]):
                    valid = False
        if not valid:
            print(f'[ERROR] Particle {pi} has invalid value in cov')
            continue  # skip this particle

        # Invert via eigen-decomposition. Clamping the eigenvalues keeps
        # degenerate (flat) matrices invertible.
        sig, Q = ti.sym_eig(cov)
        for idx in ti.static(range(3)):
            sig[idx] = ti.max(sig[idx], 1e-8)

        sig_mat = ti.Matrix(
            [[1.0 / sig[0], 0, 0], [0, 1.0 / sig[1], 0], [0, 0, 1.0 / sig[2]]]
        )
        cov = Q @ sig_mat @ Q.transpose()

        # Splat radius in world units: square root of the largest clamped eigenvalue.
        r = 0.0
        for idx in ti.static(range(3)):
            r = ti.max(r, ti.sqrt(sig[idx]))

        # Keep the radius in cells in its own integer variable. Assigning an
        # int back into the float ``r`` would silently cast it to float again
        # and make the range() bounds below floating-point.
        r_cells = ti.ceil(r / grid_dx, dtype=int)

        max_r = 6  # Safety Upper Bound For r
        if r_cells > max_r:
            print("[WARNING] Large r at particle", pi, "->", r_cells)
            r_cells = max_r

        for dx in range(-r_cells, r_cells + 1):
            for dy in range(-r_cells, r_cells + 1):
                for dz in range(-r_cells, r_cells + 1):
                    gi, gj, gk = i + dx, j + dy, k + dz
                    if (
                        0 <= gi < grid_density.shape[0] and
                        0 <= gj < grid_density.shape[1] and
                        0 <= gk < grid_density.shape[2]
                    ):
                        density = compute_density(
                            ti.Vector([gi, gj, gk]), pos, opacity[pi], cov, grid_dx
                        )
                        ti.atomic_add(grid_density[gi, gj, gk], density)
 

@ti.kernel
def fill_dense_grids(
    grid: ti.template(),
    grid_density: ti.template(),
    grid_dx: float,
    density_thres: float,
    new_particles: ti.template(),
    start_idx: int,
    max_particles_per_cell: int,
) -> int:
    """Top up every dense cell to ``max_particles_per_cell`` particles.

    A cell is dense when ``grid_density > density_thres``. For each dense cell
    holding fewer than ``max_particles_per_cell`` particles, the missing
    particles are sampled uniformly at random inside the cell. They are written
    to consecutive slots of ``new_particles`` starting at ``start_idx``, and the
    cell's count in ``grid`` is set to ``max_particles_per_cell``.

    Returns:
        ``start_idx`` plus the number of particles requested. Slots at or past
        ``new_particles.shape[0]`` are not written, to avoid out-of-bounds
        stores, so the returned value may exceed the buffer capacity.
    """
    new_start_idx = start_idx
    for i, j, k in grid_density:
        if grid_density[i, j, k] > density_thres:
            if grid[i, j, k] < max_particles_per_cell:
                diff = max_particles_per_cell - grid[i, j, k]
                grid[i, j, k] = max_particles_per_cell
                # Reserve `diff` consecutive output slots for this cell.
                tmp_start_idx = ti.atomic_add(new_start_idx, diff)
                # Never write past the end of the output buffer.
                end_idx = ti.min(tmp_start_idx + diff, new_particles.shape[0])

                for index in range(tmp_start_idx, end_idx):
                    di = ti.random()
                    dj = ti.random()
                    dk = ti.random()
                    new_particles[index] = ti.Vector([i + di, j + dj, k + dk]) * grid_dx

    return new_start_idx


@ti.func
def collision_search(
    grid: ti.template(), grid_density: ti.template(), index, dir_type, size, threshold
) -> bool:
    """Walk from ``index`` along one axis direction and report whether a dense cell is hit.

    ``dir_type`` selects the direction (0: +x, 1: -x, 2: +y, 3: -y, 4: +z,
    5: -z). Starting at the neighbour of ``index``, cells are visited until one
    has ``grid_density > threshold`` (returns True) or the walk leaves
    ``[0, size)`` on any axis (returns False). The same ``size`` bounds all
    three axes. ``grid`` is unused; it is kept for interface compatibility.

    An out-of-range ``dir_type`` returns False immediately. With a zero step
    vector the walk would never advance and could loop forever.
    """
    step = ti.Vector([0, 0, 0])
    flag = False
    if dir_type > 5 or dir_type < 0:
        flag = False
    else:
        if dir_type == 0:
            step[0] = 1
        elif dir_type == 1:
            step[0] = -1
        elif dir_type == 2:
            step[1] = 1
        elif dir_type == 3:
            step[1] = -1
        elif dir_type == 4:
            step[2] = 1
        elif dir_type == 5:
            step[2] = -1

        index += step
        i, j, k = index
        while ti.max(i, j, k) < size and ti.min(i, j, k) >= 0:
            if grid_density[index] > threshold:
                flag = True
                break
            index += step
            i, j, k = index

    return flag


@ti.func
def collision_times(
    grid: ti.template(), grid_density: ti.template(), index, dir_type, size, threshold
) -> int:
    """Count how many times a ray cast from ``index`` enters a dense region.

    ``dir_type`` selects the ray direction (0: +x, 1: -x, 2: +y, 3: -y,
    4: +z, 5: -z). The starting state is ``grid[index] > 0``. Every later cell
    along the ray has state ``grid_density > threshold``, until the ray leaves
    ``[0, size)`` on any axis. The return value counts False -> True
    transitions.

    An out-of-range ``dir_type`` returns 1 without casting a ray.
    internal_filling's odd-count test treats an odd result as inside.

    NOTE(review): the start cell is classified by particle count (``grid``),
    but the rest of the ray uses ``grid_density``. internal_filling only calls
    this for cells with ``grid == 0``, so the start state is False there.
    Confirm the asymmetry is intended.
    """
    dir = ti.Vector([0, 0, 0])
    times = 0
    if dir_type > 5 or dir_type < 0:
        # Invalid direction: report a single hit without casting a ray.
        times = 1
    else:
        if dir_type == 0:
            dir[0] = 1
        elif dir_type == 1:
            dir[0] = -1
        elif dir_type == 2:
            dir[1] = 1
        elif dir_type == 3:
            dir[1] = -1
        elif dir_type == 4:
            dir[2] = 1
        elif dir_type == 5:
            dir[2] = -1

        # The start state comes from the particle count, not the density field.
        state = grid[index] > 0
        index += dir
        i, j, k = index
        # The same `size` bounds all three axes (the grid is assumed cubic).
        while ti.max(i, j, k) < size and ti.min(i, j, k) >= 0:
            new_state = grid_density[index] > threshold
            # Count only empty -> dense transitions (entering a dense region).
            if new_state != state and state == False:
                times += 1
            state = new_state
            index += dir
            i, j, k = index

    return times


@ti.kernel
def internal_filling(
    grid: ti.template(),
    grid_density: ti.template(),
    grid_dx: float,
    new_particles: ti.template(),
    start_idx: int,
    max_particles_per_cell: int,
    exclude_dir: int,
    ray_cast_dir: int,
    threshold: float,
) -> int:
    """Fill empty cells that appear to be enclosed by dense cells.

    A cell with ``grid[i, j, k] == 0`` is filled when both conditions hold:
      1. Walking from it along every axis direction except ``exclude_dir``
         reaches a cell with ``grid_density > threshold`` (collision_search).
      2. The number of times a ray along ``ray_cast_dir`` enters a dense
         region is odd (collision_times).
    A filled cell gets ``max_particles_per_cell`` particles placed uniformly at
    random inside it. They are written to ``new_particles`` starting at
    ``start_idx``.

    Both ray helpers receive ``grid.shape[0]`` as the bound for every axis,
    which assumes a cubic grid.

    Returns:
        ``start_idx`` plus the number of particles requested. Slots at or past
        ``new_particles.shape[0]`` are not written, to avoid out-of-bounds
        stores, so the returned value may exceed the buffer capacity.
    """
    new_start_idx = start_idx
    for i, j, k in grid:
        if grid[i, j, k] == 0:
            collision_hit = True
            for dir_type in ti.static(range(6)):
                if dir_type != exclude_dir:
                    hit_test = collision_search(
                        grid=grid,
                        grid_density=grid_density,
                        index=ti.Vector([i, j, k]),
                        dir_type=dir_type,
                        size=grid.shape[0],
                        threshold=threshold,
                    )
                    collision_hit = collision_hit and hit_test

            if collision_hit:
                hit_times = collision_times(
                    grid=grid,
                    grid_density=grid_density,
                    index=ti.Vector([i, j, k]),
                    dir_type=ray_cast_dir,
                    size=grid.shape[0],
                    threshold=threshold,
                )

                # Odd number of entries into dense regions -> treated as inside.
                if ti.math.mod(hit_times, 2) == 1:
                    diff = max_particles_per_cell - grid[i, j, k]
                    grid[i, j, k] = max_particles_per_cell
                    # Reserve `diff` consecutive output slots for this cell.
                    tmp_start_idx = ti.atomic_add(new_start_idx, diff)
                    # Never write past the end of the output buffer.
                    end_idx = ti.min(tmp_start_idx + diff, new_particles.shape[0])
                    for index in range(tmp_start_idx, end_idx):
                        di = ti.random()
                        dj = ti.random()
                        dk = ti.random()
                        new_particles[index] = (
                            ti.Vector([i + di, j + dj, k + dk]) * grid_dx
                        )

    return new_start_idx


# assign_particle_to_grid is defined once, below compute_particle_volume.
# A second, identical copy used to sit here. The later definition silently
# replaced it at import time, so this copy was dead code.


@ti.kernel
def compute_particle_volume(
    pos: ti.template(), grid: ti.template(), particle_vol: ti.template(), grid_dx: float
):
    """Give each particle an equal share of its cell's volume.

    Sets ``particle_vol[pi] = grid_dx**3 / grid[cell]``, where ``cell`` is the
    cell containing ``pos[pi]``. ``grid`` holds per-cell particle counts, as
    filled by assign_particle_to_grid in get_particle_volume.

    NOTE(review): the cell index is not bounds-checked, so a particle outside
    the grid reads out of bounds.
    """
    for pi in range(pos.shape[0]):
        point = pos[pi]
        ci = ti.floor(point[0] / grid_dx, dtype=int)
        cj = ti.floor(point[1] / grid_dx, dtype=int)
        ck = ti.floor(point[2] / grid_dx, dtype=int)
        cell_volume = grid_dx * grid_dx * grid_dx
        particle_vol[pi] = cell_volume / grid[ci, cj, ck]


@ti.kernel
def assign_particle_to_grid(
    pos: ti.template(),
    grid: ti.template(),
    grid_dx: float,
):
    """Count how many particles fall into each cell of ``grid``.

    Each particle increments the cell ``floor(pos / grid_dx)``. A particle
    whose cell lies outside ``grid`` is ignored rather than written out of
    bounds; this is the same bounds check densify_grids applies.
    """
    for pi in range(pos.shape[0]):
        p = pos[pi]
        i = ti.floor(p[0] / grid_dx, dtype=int)
        j = ti.floor(p[1] / grid_dx, dtype=int)
        k = ti.floor(p[2] / grid_dx, dtype=int)
        if 0 <= i < grid.shape[0] and 0 <= j < grid.shape[1] and 0 <= k < grid.shape[2]:
            ti.atomic_add(grid[i, j, k], 1)


def get_particle_volume(pos, grid_n: int, grid_dx: float, unifrom: bool = False):
    """Estimate a volume per particle from the particle count of its grid cell.

    Particles are binned into a ``grid_n``^3 grid with cell size ``grid_dx``.
    Each particle gets ``grid_dx**3`` divided by its cell's count.

    Args:
        pos: particle positions, reshaped to (-1, 3) before upload.
        grid_n: number of cells along each axis.
        grid_dx: cell edge length.
        unifrom: (sic) if True, every particle gets the mean volume instead of
            its own.

    Returns:
        A 1-D torch tensor with one volume per particle.
    """
    num_particles = pos.shape[0]
    positions = ti.Vector.field(n=3, dtype=float, shape=num_particles)
    positions.from_torch(pos.reshape(-1, 3))

    counts = ti.field(dtype=int, shape=(grid_n, grid_n, grid_n))
    volumes = ti.field(dtype=float, shape=num_particles)

    assign_particle_to_grid(positions, counts, grid_dx)
    compute_particle_volume(positions, counts, volumes, grid_dx)

    vol = volumes.to_torch()
    if not unifrom:
        return vol
    return torch.mean(vol).repeat(num_particles)


def fill_particles(
    pos,
    opacity,
    cov,
    grid_n: int,
    max_samples: int,
    grid_dx: float,
    density_thres=2.0,
    search_thres=1.0,
    max_particles_per_cell=1,
    search_exclude_dir=5,
    ray_cast_dir=4,
    boundary: list = None,
    smooth: bool = False,
):
    """Sample extra particles inside the volume covered by the input particles.

    Steps:
      1. Optionally crop to ``boundary`` and shift positions so the box starts
         at the origin.
      2. Splat the particles onto a ``grid_n``^3 grid (densify_grids).
      3. Add particles to cells whose density exceeds ``density_thres``
         (fill_dense_grids).
      4. Optionally smooth the density field with ``mcubes.smooth``.
      5. Add particles to enclosed empty cells found by ray casting
         (internal_filling).

    Args:
        pos, opacity, cov: per-particle tensors. They are reshaped to (-1, 3),
            (-1,) and (-1, 6) before being copied into Taichi fields.
        grid_n: number of cells along each axis.
        max_samples: capacity of the buffer that receives new particles.
        grid_dx: cell size. When ``boundary`` is given it is recomputed as
            (largest box extent) / grid_n.
        density_thres: density above which a cell is filled directly.
        search_thres: density threshold used by the internal-filling ray tests.
        max_particles_per_cell: target particle count for a filled cell.
        search_exclude_dir, ray_cast_dir: direction codes
            (0: x, 1: -x, 2: y, 3: -y, 4: z, 5: -z).
        boundary: optional [xmin, xmax, ymin, ymax, zmin, zmax]. Only particles
            strictly inside this box are used to drive the filling.
        smooth: smooth the density field before internal filling.

    Returns:
        A tensor with every original ``pos`` row (uncropped), followed by the
        new particle positions, on the same device as ``pos``. At most
        ``max_samples`` new particles are returned; a warning is printed if
        more were requested.
    """
    pos_clone = pos.clone()
    # Follow the input's device instead of hard-coding CUDA.
    device = pos_clone.device
    if boundary is not None:
        assert len(boundary) == 6
        mask = torch.ones(pos_clone.shape[0], dtype=torch.bool, device=device)
        max_diff = 0.0
        for axis in range(3):
            lower, upper = boundary[2 * axis], boundary[2 * axis + 1]
            mask = torch.logical_and(mask, pos_clone[:, axis] > lower)
            mask = torch.logical_and(mask, pos_clone[:, axis] < upper)
            max_diff = max(max_diff, upper - lower)

        pos = pos[mask]
        opacity = opacity[mask]
        cov = cov[mask]

        # One cell size for all axes, so that grid_n cells span the box's longest side.
        grid_dx = max_diff / grid_n
        new_origin = torch.tensor(
            [boundary[0], boundary[2], boundary[4]], device=device
        )
        pos = pos - new_origin

    ti_pos = ti.Vector.field(n=3, dtype=float, shape=pos.shape[0])
    ti_opacity = ti.field(dtype=float, shape=opacity.shape[0])
    ti_cov = ti.Vector.field(n=6, dtype=float, shape=cov.shape[0])
    ti_pos.from_torch(pos.reshape(-1, 3))
    ti_opacity.from_torch(opacity.reshape(-1))
    ti_cov.from_torch(cov.reshape(-1, 6))

    grid = ti.field(dtype=int, shape=(grid_n, grid_n, grid_n))
    grid_density = ti.field(dtype=float, shape=(grid_n, grid_n, grid_n))
    particles = ti.Vector.field(n=3, dtype=float, shape=max_samples)

    # compute density_field
    densify_grids(ti_pos, ti_opacity, ti_cov, grid, grid_density, grid_dx)

    # fill dense grids
    fill_num = fill_dense_grids(
        grid,
        grid_density,
        grid_dx,
        density_thres,
        particles,
        0,
        max_particles_per_cell,
    )
    print("after dense grids: ", fill_num)

    # smooth density_field
    if smooth:
        df = grid_density.to_numpy()
        smoothed_df = mcubes.smooth(df, method="constrained", max_iters=500).astype(
            np.float32
        )
        grid_density.from_numpy(smoothed_df)
        print("smooth finished")

    # fill internal grids
    fill_num = internal_filling(
        grid,
        grid_density,
        grid_dx,
        particles,
        fill_num,
        max_particles_per_cell,
        exclude_dir=search_exclude_dir,  # 0: x, 1: -x, 2: y, 3: -y, 4: z, 5: -z direction
        ray_cast_dir=ray_cast_dir,  # 0: x, 1: -x, 2: y, 3: -y, 4: z, 5: -z direction
        threshold=search_thres,
    )
    print("after internal grids: ", fill_num)

    # The kernels report how many particles they wanted to add. The buffer
    # only holds max_samples, so anything beyond that is lost; say so loudly.
    if fill_num > max_samples:
        print(
            f"[WARNING] {fill_num} new particles requested but max_samples is "
            f"{max_samples}; only the first {max_samples} are returned"
        )

    # put new particles together with original particles
    particles_tensor = particles.to_torch()[:fill_num].to(device)
    if boundary is not None:
        particles_tensor = particles_tensor + new_origin
    return torch.cat([pos_clone, particles_tensor], dim=0)


@ti.kernel
def get_attr_from_closest(
    ti_pos: ti.template(),
    ti_shs: ti.template(),
    ti_opacity: ti.template(),
    ti_cov: ti.template(),
    ti_new_pos: ti.template(),
    ti_new_shs: ti.template(),
    ti_new_opacity: ti.template(),
    ti_new_cov: ti.template(),
):
    """Copy shs/opacity/cov from the nearest original particle to each new particle.

    This is a brute-force search: every new particle scans all of ``ti_pos``
    (O(N*M)) and keeps the first index with the smallest Euclidean distance.

    NOTE(review): if no distance is below 1e10 (e.g. ``ti_pos`` has no
    entries), the index stays -1 and is still used to read the attribute fields.
    """
    for qi in range(ti_new_pos.shape[0]):
        query = ti_new_pos[qi]
        best_dist = 1e10
        best_idx = -1
        for cand in range(ti_pos.shape[0]):
            d = (query - ti_pos[cand]).norm()
            if d < best_dist:
                best_dist = d
                best_idx = cand
        ti_new_shs[qi] = ti_shs[best_idx]
        ti_new_opacity[qi] = ti_opacity[best_idx]
        ti_new_cov[qi] = ti_cov[best_idx]
        
@ti.kernel
def get_attr_from_closest_vec(
    ti_pos: ti.template(),
    ti_attr: ti.template(),
    ti_new_pos: ti.template(),
    ti_new_attr: ti.template(),
):
    """Copy one attribute field from the nearest original particle to each new particle.

    This is a brute-force O(N*M) search over ``ti_pos`` by Euclidean distance.
    Ties keep the lowest original index, because only a strictly smaller
    distance replaces the current best.

    NOTE(review): if no distance is below 1e10, the index stays -1 and is
    still used to read ``ti_attr``.
    """
    for qi in range(ti_new_pos.shape[0]):
        query = ti_new_pos[qi]
        best_dist = 1e10
        best_idx = -1
        for cand in range(ti_pos.shape[0]):
            d = (query - ti_pos[cand]).norm()
            if d < best_dist:
                best_dist = d
                best_idx = cand
        ti_new_attr[qi] = ti_attr[best_idx]

# Previous Taichi-based version of init_filled_particles, disabled by wrapping
# it in a string literal. It found neighbours with the brute-force
# get_attr_from_closest_vec kernel and returned 4 tensors (no positions).
# The active KDTree-based definition follows below.
'''
def init_filled_particles(pos, shs, opacity, scaling, rotation, new_pos):
    """
    Initializes attributes for new particles by finding the nearest original particle
    and copying its attributes. This version only works with the core GS attributes.
    """
    # Ensure shs is in a flat [N, C] format for the kernel
    original_shs_shape = shs.shape
    shs_flat = shs.reshape(pos.shape[0], -1)

    # Prepare Taichi fields for original particle attributes
    ti_pos = ti.Vector.field(n=3, dtype=ti.f32, shape=pos.shape[0])
    ti_shs = ti.Vector.field(n=shs_flat.shape[1], dtype=ti.f32, shape=shs_flat.shape[0])
    ti_opacity = ti.field(dtype=ti.f32, shape=opacity.shape[0])
    ti_scaling = ti.Vector.field(n=scaling.shape[1], dtype=ti.f32, shape=scaling.shape[0])
    ti_rotation = ti.Vector.field(n=rotation.shape[1], dtype=ti.f32, shape=rotation.shape[0])

    ti_pos.from_torch(pos)
    ti_shs.from_torch(shs_flat)
    ti_opacity.from_torch(opacity.squeeze()) # Squeeze to be safe
    ti_scaling.from_torch(scaling)
    ti_rotation.from_torch(rotation)

    # Prepare Taichi fields for the new particles to be filled
    ti_new_pos = ti.Vector.field(n=3, dtype=ti.f32, shape=new_pos.shape[0])
    ti_new_shs = ti.Vector.field(n=shs_flat.shape[1], dtype=ti.f32, shape=new_pos.shape[0])
    ti_new_opacity = ti.field(dtype=ti.f32, shape=new_pos.shape[0])
    ti_new_scaling = ti.Vector.field(n=scaling.shape[1], dtype=ti.f32, shape=new_pos.shape[0])
    ti_new_rotation = ti.Vector.field(n=rotation.shape[1], dtype=ti.f32, shape=new_pos.shape[0])

    ti_new_pos.from_torch(new_pos)

    # Find nearest original particle and copy all attributes
    get_attr_from_closest_vec(ti_pos, ti_shs, ti_new_pos, ti_new_shs)
    # Note: We need a scalar version for opacity or adapt get_attr_from_closest_vec
    # Let's create a simple kernel for scalar attributes
    @ti.kernel
    def get_scalar_attr_from_closest_vec(
        p_pos: ti.template(), p_attr: ti.template(),
        n_pos: ti.template(), n_attr: ti.template()):
        for i in range(n_pos.shape[0]):
            p = n_pos[i]
            min_dist_sq = 1e10
            min_idx = -1
            for j in range(p_pos.shape[0]):
                dist_sq = (p - p_pos[j]).norm_sqr()
                if dist_sq < min_dist_sq:
                    min_dist_sq = dist_sq
                    min_idx = j
            if min_idx != -1:
                n_attr[i] = p_attr[min_idx]

    get_scalar_attr_from_closest_vec(ti_pos, ti_opacity, ti_new_pos, ti_new_opacity)
    get_attr_from_closest_vec(ti_pos, ti_scaling, ti_new_pos, ti_new_scaling)
    get_attr_from_closest_vec(ti_pos, ti_rotation, ti_new_pos, ti_new_rotation)

    # Convert back to torch tensors
    new_shs_flat = ti_new_shs.to_torch()
    new_opacity = ti_new_opacity.to_torch().unsqueeze(1) # Ensure shape is [M, 1]
    new_scaling = ti_new_scaling.to_torch()
    new_rotation = ti_new_rotation.to_torch()

    # Combine original and new attributes
    final_shs = torch.cat([shs_flat, new_shs_flat], dim=0).view(-1, *original_shs_shape[1:])
    final_opacity = torch.cat([opacity, new_opacity], dim=0)
    final_scaling = torch.cat([scaling, new_scaling], dim=0)
    final_rotation = torch.cat([rotation, new_rotation], dim=0)

    # Return only the combined tensors needed for saving
    return final_shs, final_opacity, final_scaling, final_rotation
'''    
    
    
def init_filled_particles(pos, shs, opacity, scaling, rotation, new_pos):
    """
    Initializes attributes for new particles by finding the nearest original particle
    and copying its attributes.

    A scipy KDTree over the original positions answers one nearest-neighbour
    query per new particle, avoiding the O(N*M) brute-force scan. The matching
    rows of ``shs``, ``opacity``, ``scaling`` and ``rotation`` are gathered and
    appended after the original rows.

    Args:
        pos: (N, 3) tensor of original particle positions.
        shs, opacity, scaling, rotation: per-particle attribute tensors whose
            first dimension is N. Each may live on any device; the gather
            index is moved to each tensor's own device.
        new_pos: (M, 3) tensor of new particle positions.

    Returns:
        ``(final_pos, final_shs, final_opacity, final_scaling, final_rotation)``.
        Each is the concatenation along dim 0 of the original rows followed by
        the M new rows.

    Raises:
        ValueError: if new particles are requested but ``pos`` is empty, so no
            nearest neighbour exists.
    """
    num_new = new_pos.shape[0]
    print(f"[INFO] Initializing attributes for {num_new} new particles...")

    if num_new == 0:
        # Nothing to query; an empty index gathers zero rows.
        nearest_indices = torch.empty(0, dtype=torch.long)
    else:
        if pos.shape[0] == 0:
            # A KDTree over zero points would return an out-of-range index.
            raise ValueError(
                "init_filled_particles: cannot copy attributes, there are no original particles"
            )
        # KDTree runs on the CPU, so positions are moved there first.
        print("[INFO] Building KDTree for nearest neighbor search...")
        kdtree = KDTree(pos.detach().cpu().numpy())

        print("[INFO] Querying KDTree...")
        _, idx = kdtree.query(new_pos.detach().cpu().numpy(), k=1)
        nearest_indices = torch.from_numpy(np.asarray(idx)).long()

    def _gather(attr):
        # index_select requires the index on the same device as the source tensor.
        return torch.index_select(attr, 0, nearest_indices.to(attr.device))

    print("[INFO] Gathering attributes from nearest neighbors...")
    new_shs = _gather(shs)
    new_opacity = _gather(opacity)
    new_scaling = _gather(scaling)
    new_rotation = _gather(rotation)

    # Optional (disabled) overrides for the copied attributes:
    #new_shs[:, 1:, :] = 0
    #new_opacity[:] = -100.0        ############ set as 0

    print("[INFO] Concatenating final attribute tensors...")
    final_pos = torch.cat([pos, new_pos], dim=0)
    final_shs = torch.cat([shs, new_shs], dim=0)
    final_opacity = torch.cat([opacity, new_opacity], dim=0)
    final_scaling = torch.cat([scaling, new_scaling], dim=0)
    final_rotation = torch.cat([rotation, new_rotation], dim=0)

    # final_pos is returned too, for a uniform tuple, even though callers already have it.
    return final_pos, final_shs, final_opacity, final_scaling, final_rotation
